import os
os.chdir("../")
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import numpy as np
import torch
import random
import matplotlib.pyplot as plt

from phone_booth_collab_maze import PBCMaze
from models.r2d2 import R2D2Agent
from models.r2d2 import OBLR2D2Agent
from models.r2d2_config import initial_exploration, batch_size, update_target, log_interval, eval_argmax, eval_interval, device, replay_memory_capacity, lr, sequence_length, local_mini_batch
from utils.pbmaze_config import env_config
from pbcmaze_belief_model import ReceiverBeliefModel, SenderBeliefModel
from utils.memory import Memory, LocalBuffer, OBLMemory, OBLLocalBuffer, MIOBLMemory, MIOBLLocalBuffer

from utils.pbmaze_config import iql_env_config as iql_env_config
from utils.pbmaze_config import env_config as obl_env_config

# Both paths are relative to the working directory, which the script moves to
# the parent of the launch directory via the os.chdir("../") at the top.
RESULT_PATH = "results/thesis_submission_results/"
MODEL_PATH = "results/thesis_submission_trained_models/"


def set_seed(seed):
    """Seed every random number generator this script draws from.

    PyTorch, Python's ``random`` module and NumPy's global RNG all receive the
    same ``seed`` so that repeated runs are reproducible.
    """
    seeders = (torch.manual_seed, random.seed, np.random.seed)
    for seeder in seeders:
        seeder(seed)

set_seed(0)

# Evaluation environment. Note: the IQL environment config is used for every
# agent below, including the OBL ones.
env = PBCMaze(env_args=iql_env_config)
env.reset()

a0_input_shape = env.get_obs_size(0)
a1_input_shape = env.get_obs_size(1)
a0_num_actions = 7  # sender (agent 0) action count; must match len(bar_names)
a1_num_actions = 5  # receiver (agent 1) action count

# Uniform initial policies (presumably the belief models' priors); only the
# receiver one is consumed here, by the receiver belief model.
receiver_pi_0 = [0.2, 0.2, 0.2, 0.2, 0.2]
sender_pi_0 = [1/7, 1/7, 1/7, 1/7, 1/7, 1/7, 1/7]
rb_model = ReceiverBeliefModel(receiver_pi_0, env)

# Size of the last dimension of the zero-initialised recurrent state passed to
# every sender network; must match the checkpoints loaded below.
HIDDEN_SIZE = 16
# Cap on scripted navigation steps so the script fails loudly instead of
# looping forever if the agents never reach their booths.
MAX_NAV_STEPS = 10_000
# Smallest y-axis upper limit for the policy panels. All panels share one
# limit, raised automatically if any bar would otherwise be clipped.
MIN_POLICY_YMAX = 0.35


def _zero_hidden():
    """Return a fresh pair of zero tensors of shape (1, 1, HIDDEN_SIZE) on ``device``."""
    return (torch.zeros(1, 1, HIDDEN_SIZE, device=device),
            torch.zeros(1, 1, HIDDEN_SIZE, device=device))


def _load_online_net(agent, filename):
    """Load the state dict stored at ``MODEL_PATH + filename`` into ``agent.online_net``."""
    state_dict = torch.load(MODEL_PATH + filename, map_location=device)
    agent.online_net.load_state_dict(state_dict)


def _to_numpy(policy):
    """Detach ``policy``, move it to host memory and squeeze it for plotting.

    ``.cpu()`` is required: ``Tensor.numpy()`` raises for tensors on a CUDA
    device, and the networks above run on ``device``.
    """
    return policy.detach().cpu().numpy().squeeze()


# Sender checkpoints. File names are kept verbatim (including the space after
# "model").
iql_a0_agent = R2D2Agent(a0_input_shape, a0_num_actions, Memory(replay_memory_capacity), LocalBuffer(), lr, batch_size, device)
iql_a0_hidden = _zero_hidden()
_load_online_net(iql_a0_agent, "iql_sender_model _argmax_online_net.pt")

# This used to load the plain IQL checkpoint above, which made the "IQL + IR"
# panel a copy of "IQL"; the intermediate-reward run is the "_ir_argmax" file.
iql_ir_a0_agent = R2D2Agent(a0_input_shape, a0_num_actions, Memory(replay_memory_capacity), LocalBuffer(), lr, batch_size, device)
iql_ir_a0_hidden = _zero_hidden()
_load_online_net(iql_ir_a0_agent, "iql_sender_model _ir_argmax_online_net.pt")

obl_a0_agent = OBLR2D2Agent(a0_input_shape, a0_num_actions, Memory(replay_memory_capacity), LocalBuffer(), OBLMemory(replay_memory_capacity), OBLLocalBuffer(), lr, batch_size, device, 0, rb_model)
obl_a0_hidden = _zero_hidden()
_load_online_net(obl_a0_agent, "obl_sender_model _argmax_online_net.pt")

obl_ir_a0_agent = OBLR2D2Agent(a0_input_shape, a0_num_actions, Memory(replay_memory_capacity), LocalBuffer(), OBLMemory(replay_memory_capacity), OBLLocalBuffer(), lr, batch_size, device, 0, rb_model)
obl_ir_a0_hidden = _zero_hidden()
_load_online_net(obl_ir_a0_agent, "obl_sender_model _ir_argmax_online_net.pt")

obl_mi_a0_agent = OBLR2D2Agent(a0_input_shape, a0_num_actions, Memory(replay_memory_capacity), LocalBuffer(), OBLMemory(replay_memory_capacity), OBLLocalBuffer(), lr, batch_size, device, 0, rb_model)
obl_mi_a0_hidden = _zero_hidden()
_load_online_net(obl_mi_a0_agent, "obl_sender_model _mi_log2_argmax_online_net.pt")

obl_miloss_a0_agent = OBLR2D2Agent(a0_input_shape, a0_num_actions, Memory(replay_memory_capacity), LocalBuffer(), MIOBLMemory(replay_memory_capacity), MIOBLLocalBuffer(), lr, batch_size, device, 0, rb_model)
obl_miloss_a0_hidden = _zero_hidden()
_load_online_net(obl_miloss_a0_agent, "obl_sender_model _mi_loss_argmax_online_net.pt")

obl_mi_miloss_a0_agent = OBLR2D2Agent(a0_input_shape, a0_num_actions, Memory(replay_memory_capacity), LocalBuffer(), MIOBLMemory(replay_memory_capacity), MIOBLLocalBuffer(), lr, batch_size, device, 0, rb_model)
obl_mi_miloss_a0_hidden = _zero_hidden()
_load_online_net(obl_mi_miloss_a0_agent, "obl_sender_model _mi_log2_mi_loss_argmax_online_net.pt")

# Move the sender and receiver to their phone booths with fixed scripted moves.
# Every sender network sees the sender's observation at each step so its
# recurrent state reflects the trajectory. Only the hidden states are kept;
# the networks do not choose the moves.
# IQL agents are called as get_action(obs, 0.0, hidden) (presumably 0.0 is a
# greedy epsilon); OBL agents as get_action(obs, hidden, argmax=True).
nav_steps = 0
while env.agent0_loc[0] != env.booth_loc or env.agent1_loc[0] != env.receiver_booth_loc:
    if nav_steps >= MAX_NAV_STEPS:
        raise RuntimeError(
            "agents did not reach the phone booths within "
            f"{MAX_NAV_STEPS} steps (agent 0 loc: {env.agent0_loc}, "
            f"agent 1 loc: {env.agent1_loc})"
        )
    nav_steps += 1
    a0_obs = torch.Tensor(env.get_obs(0)).to(device)
    _, _, iql_a0_hidden = iql_a0_agent.get_action(a0_obs, 0.0, iql_a0_hidden)
    _, _, iql_ir_a0_hidden = iql_ir_a0_agent.get_action(a0_obs, 0.0, iql_ir_a0_hidden)
    _, _, obl_a0_hidden = obl_a0_agent.get_action(a0_obs, obl_a0_hidden, argmax = True)
    _, _, obl_ir_a0_hidden = obl_ir_a0_agent.get_action(a0_obs, obl_ir_a0_hidden, argmax = True)
    _, _, obl_mi_a0_hidden = obl_mi_a0_agent.get_action(a0_obs, obl_mi_a0_hidden, argmax = True)
    _, _, obl_miloss_a0_hidden = obl_miloss_a0_agent.get_action(a0_obs, obl_miloss_a0_hidden, argmax = True)
    _, _, obl_mi_miloss_a0_hidden = obl_mi_miloss_a0_agent.get_action(a0_obs, obl_mi_miloss_a0_hidden, argmax = True)
    # Presumably env.step(agent_id, action): sender takes action 1, receiver
    # action 0 — TODO confirm against PBCMaze.step.
    env.step(0, 1)
    env.step(1, 0)

env.render()
print("agent 0 loc: " + str(env.agent0_loc) + " | agent 1 loc: " + str(env.agent1_loc))

# Query each sender once more at the booth; the first return value is what
# gets plotted.
a0_obs = torch.Tensor(env.get_obs(0)).to(device)
iql_a0_policy, _, _ = iql_a0_agent.get_action(a0_obs, 0.0, iql_a0_hidden)
iql_ir_a0_policy, _, _ = iql_ir_a0_agent.get_action(a0_obs, 0.0, iql_ir_a0_hidden)
obl_a0_policy, _, _ = obl_a0_agent.get_action(a0_obs, obl_a0_hidden, argmax = True)
obl_ir_a0_policy, _, _ = obl_ir_a0_agent.get_action(a0_obs, obl_ir_a0_hidden, argmax = True)
obl_mi_a0_policy, _, _ = obl_mi_a0_agent.get_action(a0_obs, obl_mi_a0_hidden, argmax = True)
obl_miloss_a0_policy, _, _ = obl_miloss_a0_agent.get_action(a0_obs, obl_miloss_a0_hidden, argmax = True)
obl_mi_miloss_a0_policy, _, _ = obl_mi_miloss_a0_agent.get_action(a0_obs, obl_mi_miloss_a0_hidden, argmax = True)

CB91_Blue = '#2CBDFE'
CB91_Green = '#47DBCD'
CB91_Pink = '#F3A0F2'
CB91_Purple = '#9D2EC5'
CB91_Violet = '#661D98'
CB91_Amber = '#F5B14C'
seventh_color = "#4b97ec"
# One colour per sender action, in the same order as bar_names.
color_list = [CB91_Blue, CB91_Pink, CB91_Green, CB91_Amber, CB91_Purple, CB91_Violet, seventh_color]

bar_names = ["Left", "Right", "Up", "Down", "No-Op", "Hint-Up", "Hint-Down"]

# (subplot position, policy tensor, panel title, extra set_title kwargs).
# The full method is highlighted with a bold title.
policy_panels = [
    ((0, 0), iql_a0_policy, "IQL", {}),
    ((0, 1), iql_ir_a0_policy, "IQL + IR", {}),
    ((1, 0), obl_a0_policy, "OBL", {}),
    ((1, 1), obl_ir_a0_policy, "OBL + IR", {}),
    ((2, 0), obl_mi_a0_policy, "OBL + MI reward", {}),
    ((2, 1), obl_miloss_a0_policy, "OBL + MI loss", {}),
    ((3, 0), obl_mi_miloss_a0_policy, "OBL + MI reward + MI loss", {"fontweight": "bold"}),
]
panel_values = [_to_numpy(policy) for _, policy, _, _ in policy_panels]

# One shared y-limit so bar heights are directly comparable across panels.
# (A trailing plt.ylim call used to rescale only the current axes, giving the
# highlighted panel a different scale from the others.)
policy_ymax = max(MIN_POLICY_YMAX, 1.05 * max(float(values.max()) for values in panel_values))

fig, axes = plt.subplots(4, 2)

for ((row, col), _, title, title_kwargs), values in zip(policy_panels, panel_values):
    ax = axes[row, col]
    ax.bar(bar_names, values, color = color_list)
    ax.set_ylabel("Action Probability")
    ax.set_xlabel("Action")
    ax.set_title(title, **title_kwargs)
    ax.set_ylim(0, policy_ymax)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Seven methods on a 4x2 grid: drop the unused last cell.
fig.delaxes(axes[3, 1])

fig.suptitle("Sender's policy at functional phone booth")
plt.show()
